Sequence tagging example

In this example, we implement a named entity tagger using two different approaches: a simple approach where a linear output unit is put on top of an RNN, and a slightly more complex approach where we use a conditional random field to predict the output. This example uses training and validation data from the CoNLL-2003 Shared Task.

NB: you can download the original data here, but the example assumes that the entities have been coded according to the BIO scheme, not the IOB scheme used originally. Please ask Richard if you want to have the processed data that works with this example.
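If you prefer to convert the data yourself, the following is a minimal sketch of the conversion, not the script that produced the processed files. It assumes that the original files use the IOB1 convention, where an entity normally starts with an I- tag and a B- tag appears only when an entity directly follows another entity of the same type. The function name iob1_to_bio is made up for this illustration.

```python
def iob1_to_bio(tags):
    """Convert the IOB1 tags of one sentence into BIO tags (assumed IOB1 input)."""
    converted = []
    prev_type = None  # entity type of the previous token, None if it was outside
    for tag in tags:
        if tag == 'O':
            converted.append(tag)
            prev_type = None
            continue
        prefix, entity_type = tag.split('-', 1)
        # In BIO, every entity starts with B-, so an I- tag that does not
        # continue an entity of the same type is rewritten.
        if prefix == 'I' and prev_type != entity_type:
            tag = 'B-' + entity_type
        converted.append(tag)
        prev_type = entity_type
    return converted

print(iob1_to_bio(['I-ORG', 'I-ORG', 'O', 'I-PER', 'O', 'O', 'I-LOC', 'O']))
```

The function works on the tags of a single sentence, so that an entity is never continued across a sentence boundary.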

You will need to install pytorch-crf if you want to run the CRF-based tagger.

In [1]:
import torch
from torch import nn
import time
import torchtext
import numpy as np

import random

from collections import defaultdict, Counter

import matplotlib.pyplot as plt

%config InlineBackend.figure_format = 'retina' 
plt.style.use('seaborn')

Reading the data in the CoNLL-2003 format

The following function reads a file represented in the CoNLL-2003 format. In this format, each row corresponds to one token. For each token, there is a word, a part-of-speech tag, a "shallow syntax" label, and the BIO-coded named entity label, separated by whitespace. The sentences are separated by empty lines. Here is an example of a sentence.

United NNP B-NP B-ORG
Nations NNP I-NP I-ORG
official NN I-NP O
Ekeus NNP B-NP B-PER
heads VBZ B-VP O
for IN B-PP O
Baghdad NNP B-NP B-LOC
. . O O

The function reads a file in this format and returns a torchtext Dataset, which in turn consists of a number of Example objects. We will use just the words and the BIO labels, as the input and output respectively.
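To see the parsing logic in isolation, here is a small sketch that works on an in-memory string and does not need torchtext. The names parse_conll and sample are made up for this illustration; the function returns plain Python lists instead of a Dataset.

```python
def parse_conll(lines):
    """Return a list of (words, labels) pairs, one pair per sentence."""
    sentences = []
    words, labels = [], []
    for line in lines:
        columns = line.split()
        if not columns:
            # An empty line ends the current sentence.
            if words:
                sentences.append((words, labels))
                words, labels = [], []
            continue
        words.append(columns[0])    # first column: the word
        labels.append(columns[-1])  # last column: the BIO label
    if words:
        # The input did not end with an empty line.
        sentences.append((words, labels))
    return sentences

sample = """United NNP B-NP B-ORG
Nations NNP I-NP I-ORG
official NN I-NP O
Ekeus NNP B-NP B-PER
heads VBZ B-VP O
for IN B-PP O
Baghdad NNP B-NP B-LOC
. . O O
"""

for words, labels in parse_conll(sample.splitlines()):
    print(list(zip(words, labels)))
```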

In [2]:
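def read_data(corpus_file, datafields):
    with open(corpus_file, encoding='utf-8') as f:
        examples = []
        words = []
        labels = []
        for line in f:
            line = line.strip()
            if not line:
                # An empty line ends the current sentence. We skip repeated empty
                # lines so that we don't create empty examples.
                if words:
                    examples.append(torchtext.data.Example.fromlist([words, labels], datafields))
                words = []
                labels = []
            else:
                columns = line.split()
                words.append(columns[0])
                labels.append(columns[-1])
        # Add the last sentence if the file does not end with an empty line.
        if words:
            examples.append(torchtext.data.Example.fromlist([words, labels], datafields))
        return torchtext.data.Dataset(examples, datafields)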

Implementing a tagger based on RNNs and a linear output unit

Our first implementation will be fairly straightforward. We apply an RNN and then a linear output unit to predict the outputs. The following figure illustrates the approach. (The figure is a bit misleading here, because we are predicting BIO labels and not part-of-speech tags, but you get the idea.)

[Figure: an RNN tagger with a linear output unit on top of each RNN state]

High-quality systems for tasks such as named entity recognition and part-of-speech tagging typically use smarter word representations, for instance by taking the characters into account more carefully. We just use word embeddings.

A small issue to note here is that we don't want the system to spend effort learning to tag the padding tokens. To make the system ignore the padding, we add a large number to the output corresponding to the dummy padding tag. This means that the loss values for these positions will be negligible.
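The effect of this trick can be checked with a few lines of arithmetic. The sketch below uses only the standard library rather than PyTorch; the helper cross_entropy, the index PAD_LABEL and the score values are made up for this illustration.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy loss for one position: log-sum-exp of the scores minus the target score."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_norm - logits[target]

PAD_LABEL = 0                   # hypothetical index of the dummy padding tag
logits = [0.3, -1.2, 0.8, 0.1]  # made-up output scores at a padding position

# Add a large number to the score of the padding tag, as described above.
boosted = list(logits)
boosted[PAD_LABEL] += 10000

loss_before = cross_entropy(logits, PAD_LABEL)
loss_after = cross_entropy(boosted, PAD_LABEL)
print(f'loss without the trick: {loss_before:.4f}')
print(f'loss with the trick:    {loss_after:.4f}')
```

After the boost, the softmax probability of the padding tag is practically 1 at the padding positions, so both the loss and its gradient with respect to the scores are close to zero there.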

Note that we structure the code a bit differently compared to our previous implementations: we compute the loss in the forward method, while previously we just computed the output in this method. The reason for this change is that the CRF (see below) uses this structure, and we want to keep the implementations compatible. Similarly, the predict method will convert from PyTorch tensors into NumPy arrays, in order to be compatible with the CRF's prediction method.

In [3]:
class RNNTagger(nn.Module):
    
    def __init__(self, text_field, label_field, emb_dim, rnn_size, update_pretrained=False):
        super().__init__()
        
        voc_size = len(text_field.vocab)
        self.n_labels = len(label_field.vocab)       
        
        # Embedding layer. If we're using pre-trained embeddings, copy them
        # into our embedding module.
        self.embedding = nn.Embedding(voc_size, emb_dim)
        if text_field.vocab.vectors is not None:
            self.embedding.weight = torch.nn.Parameter(text_field.vocab.vectors, 
                                                       requires_grad=update_pretrained)

        # RNN layer. We're using a bidirectional GRU with one layer.
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=rnn_size, 
                          bidirectional=True, num_layers=1)

        # Output layer. As in the example last week, the input will be two times
        # the RNN size since we are using a bidirectional RNN.
        self.top_layer = nn.Linear(2*rnn_size, self.n_labels)
 
        # To deal with the padding positions later, we need to know the
        # encoding of the padding dummy word and the corresponding dummy output tag.
        self.pad_word_id = text_field.vocab.stoi[text_field.pad_token]
        self.pad_label_id = label_field.vocab.stoi[label_field.pad_token]
    
        # Loss function that we will use during training.
        self.loss = torch.nn.CrossEntropyLoss(reduction='sum')
        
    def compute_outputs(self, sentences):
        # The words in the sentences are encoded as integers. The shape of the sentences
        # tensor is (max_len, n_sentences), where n_sentences is the number of sentences
        # in this batch, and max_len is the maximal length of a sentence in the batch.

        # First look up the embeddings for all the words in the sentences.
        # The shape is now (max_len, n_sentences, emb_dim).        
        embedded = self.embedding(sentences)

        # Apply the RNN.
        # The shape of the RNN output tensor is (max_len, n_sentences, 2*rnn_size).
        rnn_out, _ = self.rnn(embedded)
        
        # Apply the linear output layer.
        # The shape of the output tensor is (max_len, n_sentences, n_labels).
        out = self.top_layer(rnn_out)
        
        # Find the positions where the token is a dummy padding token.
        pad_mask = (sentences == self.pad_word_id).float()

        # For these positions, we add some large number in the column corresponding
        # to the dummy padding label.
        out[:, :, self.pad_label_id] += pad_mask*10000

        return out
                
    def forward(self, sentences, labels):
        # As discussed above, this method first computes the predictions, and then
        # the loss function.
        
        # Compute the outputs. The shape is (max_len, n_sentences, n_labels).
        scores = self.compute_outputs(sentences)
        
        # Flatten the outputs and the gold-standard labels, to compute the loss.
        # The input to this loss needs to be one 2-dimensional and one 1-dimensional tensor.
        scores = scores.view(-1, self.n_labels)
        labels = labels.view(-1)
        return self.loss(scores, labels)

    def predict(self, sentences):
        # Compute the outputs from the linear units.
        scores = self.compute_outputs(sentences)

        # Select the top-scoring labels. The shape is now (max_len, n_sentences).
        predicted = scores.argmax(dim=2)

        # We transpose the prediction to (n_sentences, max_len), and convert it
        # to a NumPy matrix.
        return predicted.t().cpu().numpy()

Implementing a conditional random field tagger

We will now add a CRF layer on top of the linear output units. The CRF will help the model handle the interactions between output tags more consistently, e.g. not mixing up B and I tags of different types. Here is a figure that shows the intuition.

[Figure: a CRF layer on top of the RNN outputs, connecting adjacent output tags]
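To make the notion of an inconsistent tag sequence concrete, the following sketch counts the positions where an I- tag does not continue an entity of the same type. This is a hard rule written by hand for illustration; a CRF instead learns a score for each pair of adjacent tags, which lets it make such transitions unlikely. The function names are made up for this example.

```python
def is_valid_bio_transition(prev_tag, tag):
    """An I-X tag is only valid directly after B-X or I-X."""
    if tag.startswith('I-'):
        entity_type = tag[2:]
        return prev_tag in ('B-' + entity_type, 'I-' + entity_type)
    # O and B- tags may follow any tag.
    return True

def count_invalid_transitions(tags):
    # The position before the first token is treated as outside (O).
    prev_tags = ['O'] + list(tags[:-1])
    return sum(not is_valid_bio_transition(p, t) for p, t in zip(prev_tags, tags))

print(count_invalid_transitions(['B-PER', 'I-PER', 'O', 'B-LOC']))
print(count_invalid_transitions(['B-PER', 'I-LOC', 'O', 'I-ORG']))
```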

The two important methods in the CRF module correspond to the two main algorithms that a CRF needs to implement:

  • decode applies the Viterbi algorithm to compute the highest-scoring sequences.
  • forward applies the forward algorithm to compute the log likelihood of the training set.

Most of the code is identical to the implementation above. The differences are in the forward and predict methods.
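For intuition about the two algorithms, here is a tiny pure-Python sketch for a single sentence. It is not how pytorch-crf is implemented: it is simplified to emission scores and a transition matrix only, leaves out the separate start and end scores, and does no batching or masking. All names and numbers are made up for this illustration.

```python
import math

def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def sequence_score(emissions, transitions, tags):
    """Score of one tag sequence: emission scores plus transition scores."""
    score = emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += transitions[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score

def viterbi(emissions, transitions):
    """Return the highest-scoring tag sequence and its score."""
    n_labels = len(emissions[0])
    best = list(emissions[0])
    backpointers = []
    for t in range(1, len(emissions)):
        new_best, pointers = [], []
        for j in range(n_labels):
            prev = max(range(n_labels), key=lambda i: best[i] + transitions[i][j])
            pointers.append(prev)
            new_best.append(best[prev] + transitions[prev][j] + emissions[t][j])
        best = new_best
        backpointers.append(pointers)
    # Follow the backpointers from the best final label.
    path = [max(range(n_labels), key=lambda j: best[j])]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return path, max(best)

def log_partition(emissions, transitions):
    """Forward algorithm: log of the summed exponentiated scores of all sequences."""
    n_labels = len(emissions[0])
    alpha = list(emissions[0])
    for t in range(1, len(emissions)):
        alpha = [logsumexp([alpha[i] + transitions[i][j] for i in range(n_labels)])
                 + emissions[t][j] for j in range(n_labels)]
    return logsumexp(alpha)

# Made-up scores for a sentence with three tokens and two labels.
emissions = [[2.0, 0.5], [0.3, 1.0], [1.5, 0.2]]
transitions = [[0.5, -0.5], [-1.0, 0.8]]

path, score = viterbi(emissions, transitions)
gold = [0, 1, 1]
log_likelihood = sequence_score(emissions, transitions, gold) - log_partition(emissions, transitions)
print(path, score, log_likelihood)
```

The log likelihood of a gold-standard sequence is its score minus the log partition value, and a negated quantity of this kind is what the forward method in the class below returns as the loss.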

In [4]:
from torchcrf import CRF

class RNNCRFTagger(nn.Module):
    
    def __init__(self, text_field, label_field, emb_dim, rnn_size, update_pretrained=False):
        super().__init__()
        
        voc_size = len(text_field.vocab)
        self.n_labels = len(label_field.vocab)       
        
        self.embedding = nn.Embedding(voc_size, emb_dim)
        if text_field.vocab.vectors is not None:
            self.embedding.weight = torch.nn.Parameter(text_field.vocab.vectors, 
                                                       requires_grad=update_pretrained)

        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=rnn_size, 
                          bidirectional=True, num_layers=1)

        self.top_layer = nn.Linear(2*rnn_size, self.n_labels)
 
        self.pad_word_id = text_field.vocab.stoi[text_field.pad_token]
        self.pad_label_id = label_field.vocab.stoi[label_field.pad_token]
    
        self.crf = CRF(self.n_labels)
        
    def compute_outputs(self, sentences):
        embedded = self.embedding(sentences)
        rnn_out, _ = self.rnn(embedded)
        out = self.top_layer(rnn_out)
        
        pad_mask = (sentences == self.pad_word_id).float()
        out[:, :, self.pad_label_id] += pad_mask*10000
        
        return out
                
    def forward(self, sentences, labels):
        # Compute the outputs of the lower layers, which will be used as emission
        # scores for the CRF.
        scores = self.compute_outputs(sentences)

        # We return the loss value. The CRF returns the log likelihood, but we return 
        # the *negative* log likelihood as the loss value.            
        # PyTorch's optimizers *minimize* the loss, while we want to *maximize* the
        # log likelihood.
        return -self.crf(scores, labels)
            
    def predict(self, sentences):
        # Compute the emission scores, as above.
        scores = self.compute_outputs(sentences)

        # Apply the Viterbi algorithm to get the predictions. This implementation returns
        # the result as a list of lists (not a tensor), corresponding to a matrix
        # of shape (n_sentences, max_len).
        return self.crf.decode(scores)

Evaluating the predicted named entities

To evaluate our named entity recognizers, we compare the named entities predicted by the system to the entities in the gold standard. We follow standard practice and compute precision and recall scores, as well as the harmonic mean of the precision and recall, known as the F-score.

Please note that the precision and recall scores are computed with respect to the full named entity spans and labels. To be counted as a correct prediction, the system needs to predict all words in the named entity correctly, and assign the right type of entity label. We don't give any credit for partially correct predictions.
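As a small worked example of this strict matching, the sketch below evaluates one sentence where the system finds only the first word of a two-word organization name. It works directly on tag strings and uses sets of spans, so it is a simplified illustration rather than the evaluation code of this notebook; the function names are made up.

```python
def bio_to_spans(tags):
    """Return the set of (start, end, label) spans in a list of BIO tag strings."""
    spans = set()
    start, label = None, None
    # The extra 'O' at the end closes an entity that reaches the end of the sentence.
    for i, tag in enumerate(list(tags) + ['O']):
        continues = tag.startswith('I-') and tag[2:] == label
        if label is not None and not continues:
            spans.add((start, i, label))
            start, label = None, None
        if tag != 'O' and not continues:
            start, label = i, tag[2:]
    return spans

def precision_recall_f1(gold_tags, pred_tags):
    gold = bio_to_spans(gold_tags)
    pred = bio_to_spans(pred_tags)
    correct = len(gold & pred)  # spans with identical boundaries and label
    p = correct / len(pred) if pred else 0.0
    r = correct / len(gold) if gold else 0.0
    f = 2 * p * r / (p + r) if p + r > 0 else 0.0
    return p, r, f

gold_tags = ['B-ORG', 'I-ORG', 'O', 'B-PER', 'O', 'O', 'B-LOC', 'O']
pred_tags = ['B-ORG', 'O', 'O', 'B-PER', 'O', 'O', 'B-LOC', 'O']
p, r, f = precision_recall_f1(gold_tags, pred_tags)
print(f'P = {p:.4f}, R = {r:.4f}, F1 = {f:.4f}')
```

The truncated organization span counts both as a wrong prediction and as a missed gold-standard entity, so two of three spans are correct for precision as well as for recall.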

In [5]:
# Convert a list of BIO labels, coded as integers, into spans identified by a beginning, an end, and a label.
# To allow easy comparison later, we store them in a dictionary indexed by the start position.
def to_spans(l_ids, voc):
    spans = {}
    current_lbl = None
    current_start = None
    for i, l_id in enumerate(l_ids):
        l = voc[l_id]

        if l[0] == 'B': 
            # Beginning of a named entity: B-something.
            if current_lbl:
                # If we're working on an entity, close it.
                spans[current_start] = (current_lbl, i)
            # Create a new entity that starts here.
            current_lbl = l[2:]
            current_start = i
        elif l[0] == 'I':
            # Continuation of an entity: I-something.
            if current_lbl:
                # If we have an open entity, but its label does not
                # correspond to the predicted I-tag, then we close
                # the open entity and create a new one.
                if current_lbl != l[2:]:
                    spans[current_start] = (current_lbl, i)
                    current_lbl = l[2:]
                    current_start = i
            else:
                # If we don't have an open entity but predict an I tag,
                # we create a new entity starting here even though we're
                # not following the format strictly.
                current_lbl = l[2:]
                current_start = i
        else:
            # Outside: O.
            if current_lbl:
                # If we have an open entity, we close it.
                spans[current_start] = (current_lbl, i)
                current_lbl = None
                current_start = None
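    # If an entity is still open at the end of the sequence, we close it.
    if current_lbl:
        spans[current_start] = (current_lbl, len(l_ids))
    return spans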

# Compares two sets of spans and records the results for future aggregation.
def compare(gold, pred, stats):
    for start, (lbl, end) in gold.items():
        stats['total']['gold'] += 1
        stats[lbl]['gold'] += 1
    for start, (lbl, end) in pred.items():
        stats['total']['pred'] += 1
        stats[lbl]['pred'] += 1
    for start, (glbl, gend) in gold.items():
        if start in pred:
            plbl, pend = pred[start]
            if glbl == plbl and gend == pend:
                stats['total']['corr'] += 1
                stats[glbl]['corr'] += 1

# This function combines the auxiliary functions we defined above.
def evaluate_iob(predicted, gold, label_field, stats):
    # The gold-standard labels are assumed to be an integer tensor of shape
    # (max_len, n_sentences), as returned by torchtext.
    gold_cpu = gold.t().cpu().numpy()
    gold_cpu = list(gold_cpu.reshape(-1))

    # The predicted labels assume the format produced by pytorch-crf, so we
    # assume that they have been converted into a list already.
    # We just flatten the list.
    pred_cpu = [l for sen in predicted for l in sen]
    
    # Compute spans for the gold standard and prediction.
    gold_spans = to_spans(gold_cpu, label_field.vocab.itos)
    pred_spans = to_spans(pred_cpu, label_field.vocab.itos)

    # Finally, update the counts for correct, predicted and gold-standard spans.
    compare(gold_spans, pred_spans, stats)

# Computes precision, recall and F-score, given a dictionary that contains
# the counts of correct, predicted and gold-standard items.
def prf(stats):
    if stats['pred'] == 0 or stats['gold'] == 0:
        return 0, 0, 0
    p = stats['corr']/stats['pred']
    r = stats['corr']/stats['gold']
    if p > 0 and r > 0:
        f = 2*p*r/(p+r)
    else:
        f = 0
    return p, r, f

Training the full system

We structure this a bit differently than in our previous examples, so that we can run the named entity recognizer interactively later. Most of the work is done in the train method, while the tag method can be used to process new examples.

As usual in our examples, the training procedure will create a model, train it for some epochs, and evaluate on the validation set periodically. In most cases, the CRF-based system gives slightly higher evaluation scores than the simple system.

In [6]:
class Tagger:
    
    def __init__(self, lower):
        self.TEXT = torchtext.data.Field(init_token='<bos>', eos_token='<eos>', sequential=True, lower=lower)
        self.LABEL = torchtext.data.Field(init_token='<bos>', eos_token='<eos>', sequential=True, unk_token=None)
        self.fields = [('text', self.TEXT), ('label', self.LABEL)]
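        # Use the GPU if one is available, otherwise fall back to the CPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'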
        
    def tag(self, sentences):
        # This method applies the trained model to a list of sentences.
        
        # First, create a torchtext Dataset containing the sentences to tag.
        examples = []
        for sen in sentences:
            labels = ['?']*len(sen) # placeholder
            examples.append(torchtext.data.Example.fromlist([sen, labels], self.fields))
        dataset = torchtext.data.Dataset(examples, self.fields)
        
        iterator = torchtext.data.Iterator(
            dataset,
            device=self.device,
            batch_size=64,
            repeat=False,
            train=False,
            sort=False)
        
        # Apply the trained model to all batches.
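        out = []
        n_done = 0
        self.model.eval()
        with torch.no_grad():
            for batch in iterator:
                # Call the model's predict method. This returns a list of lists or a
                # NumPy matrix containing the integer-encoded tags for each sentence.
                predicted = self.model.predict(batch.text)

                # Select the input sentences that belong to this batch. The iterator
                # neither sorts nor shuffles, so the original order is preserved.
                batch_sentences = sentences[n_done:n_done + len(predicted)]
                n_done += len(predicted)

                # Convert the integer-encoded tags to tag strings. We skip the first
                # position, which corresponds to the <bos> token.
                for tokens, pred_sen in zip(batch_sentences, predicted):
                    out.append([self.LABEL.vocab.itos[pred_id] for _, pred_id in zip(tokens, pred_sen[1:])])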
        return out
                
    def train(self):
        # Read training and validation data according to the predefined split.
        train_examples = read_data('data/eng.train.iob', self.fields)
        valid_examples = read_data('data/eng.valid.iob', self.fields)

        # Count the number of words and sentences.
        n_tokens_train = 0
        n_sentences_train = 0
        for ex in train_examples:
            n_tokens_train += len(ex.text) + 2
            n_sentences_train += 1
        n_tokens_valid = 0       
        for ex in valid_examples:
            n_tokens_valid += len(ex.text)

        # Load pre-trained GloVe embeddings. torchtext downloads them on first use.
        use_pretrained = True
        if use_pretrained:
            print('We are using pre-trained word embeddings.')
            self.TEXT.build_vocab(train_examples, vectors="glove.840B.300d")
        else:  
            print('We are training word embeddings from scratch.')
            self.TEXT.build_vocab(train_examples, max_size=5000)
        self.LABEL.build_vocab(train_examples)
    
        # Create one of the models defined above.
        #self.model = RNNTagger(self.TEXT, self.LABEL, emb_dim=300, rnn_size=128, update_pretrained=False)
        self.model = RNNCRFTagger(self.TEXT, self.LABEL, emb_dim=300, rnn_size=128, update_pretrained=False)
    
        self.model.to(self.device)
    
        batch_size = 1024
        n_batches = np.ceil(n_sentences_train / batch_size)

        mean_n_tokens = n_tokens_train / n_batches

        train_iterator = torchtext.data.BucketIterator(
            train_examples,
            device=self.device,
            batch_size=batch_size,
            sort_key=lambda x: len(x.text),
            repeat=False,
            train=True,
            sort=True)

        valid_iterator = torchtext.data.BucketIterator(
            valid_examples,
            device=self.device,
            batch_size=64,
            sort_key=lambda x: len(x.text),
            repeat=False,
            train=False,
            sort=True)
    
        train_batches = list(train_iterator)
        valid_batches = list(valid_iterator)

        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01, weight_decay=1e-5)

        n_labels = len(self.LABEL.vocab)

        history = defaultdict(list)    
        
        n_epochs = 25
        
        for i in range(1, n_epochs + 1):

            t0 = time.time()

            loss_sum = 0

            self.model.train()
            for batch in train_batches:
                
                # Compute the output and loss.
                loss = self.model(batch.text, batch.label) / mean_n_tokens
                
                optimizer.zero_grad()            
                loss.backward()
                optimizer.step()
                loss_sum += loss.item()

            train_loss = loss_sum / n_batches
            history['train_loss'].append(train_loss)

            # Evaluate on the validation set.
            if i % 1 == 0:
                stats = defaultdict(Counter)

                self.model.eval()
                with torch.no_grad():
                    for batch in valid_batches:
                        # Predict the model's output on a batch.
                        predicted = self.model.predict(batch.text)                   
                        # Update the evaluation statistics.
                        evaluate_iob(predicted, batch.label, self.LABEL, stats)
            
                # Compute the overall F-score for the validation set.
                _, _, val_f1 = prf(stats['total'])
                
                history['val_f1'].append(val_f1)
            
                t1 = time.time()
                print(f'Epoch {i}: train loss = {train_loss:.4f}, val f1: {val_f1:.4f}, time = {t1-t0:.4f}')
           
        # After the final evaluation, we print more detailed evaluation statistics, including
        # precision, recall, and F-scores for the different types of named entities.
        print()
        print('Final evaluation on the validation set:')
        p, r, f1 = prf(stats['total'])
        print(f'Overall: P = {p:.4f}, R = {r:.4f}, F1 = {f1:.4f}')
        for label in stats:
            if label != 'total':
                p, r, f1 = prf(stats[label])
                print(f'{label:4s}: P = {p:.4f}, R = {r:.4f}, F1 = {f1:.4f}')
        
        plt.plot(history['train_loss'])
        plt.plot(history['val_f1'])
        plt.legend(['training loss', 'validation F-score'])

tagger = Tagger(lower=False)
tagger.train()
We are using pre-trained word embeddings.
Epoch 1: train loss = 0.5187, val f1: 0.6162, time = 2.4319
Epoch 2: train loss = 0.2222, val f1: 0.7028, time = 2.2781
Epoch 3: train loss = 0.1429, val f1: 0.7437, time = 2.1386
Epoch 4: train loss = 0.1017, val f1: 0.7819, time = 2.2807
Epoch 5: train loss = 0.0773, val f1: 0.7529, time = 2.1722
Epoch 6: train loss = 0.0663, val f1: 0.8015, time = 2.1784
Epoch 7: train loss = 0.0589, val f1: 0.8388, time = 2.1753
Epoch 8: train loss = 0.0495, val f1: 0.8213, time = 2.1998
Epoch 9: train loss = 0.0414, val f1: 0.8511, time = 2.3104
Epoch 10: train loss = 0.0351, val f1: 0.8514, time = 2.2913
Epoch 11: train loss = 0.0324, val f1: 0.7937, time = 2.2949
Epoch 12: train loss = 0.0304, val f1: 0.8560, time = 2.2316
Epoch 13: train loss = 0.0243, val f1: 0.8607, time = 2.1826
Epoch 14: train loss = 0.0232, val f1: 0.8617, time = 2.2301
Epoch 15: train loss = 0.0213, val f1: 0.8595, time = 2.1853
Epoch 16: train loss = 0.0198, val f1: 0.8726, time = 2.2425
Epoch 17: train loss = 0.0154, val f1: 0.8575, time = 2.2749
Epoch 18: train loss = 0.0134, val f1: 0.8405, time = 2.3226
Epoch 19: train loss = 0.0126, val f1: 0.8702, time = 2.2944
Epoch 20: train loss = 0.0123, val f1: 0.8656, time = 2.2954
Epoch 21: train loss = 0.0127, val f1: 0.8699, time = 2.3179
Epoch 22: train loss = 0.0118, val f1: 0.8747, time = 2.3110
Epoch 23: train loss = 0.0133, val f1: 0.8330, time = 2.2839
Epoch 24: train loss = 0.0132, val f1: 0.8036, time = 2.3062
Epoch 25: train loss = 0.0145, val f1: 0.8514, time = 2.2759

Final evaluation on the validation set:
Overall: P = 0.8488, R = 0.8541, F1 = 0.8514
LOC : P = 0.9133, R = 0.9113, F1 = 0.9123
ORG : P = 0.9075, R = 0.7465, F1 = 0.8191
MISC: P = 0.8841, R = 0.8026, F1 = 0.8414
PER : P = 0.7525, R = 0.9012, F1 = 0.8202

Analysis of some examples

We'll run the trained named entity recognizer interactively and consider the system's behavior in a few examples. Please note that the system's output can vary a bit, depending on which model you trained, as well as on randomness in the training process.

First, we create a utility function that takes a single sentence, runs the trained system, and prints the words and output tags line by line.

In [7]:
def print_tags(sentence):
    tokens = sentence.split()
    tags = tagger.tag([tokens])[0]
    for token, tag in zip(tokens, tags):
        print(f'{token:12s}{tag}')

Here is a fairly straightforward example that in most cases will be tagged correctly, especially if you are using the CRF-based system. Note that the system needs to recognize that the word Gothenburg should be tagged differently depending on the context: as a location in the first case, and as part of an organization name in the second.

In [8]:
print_tags('John Johnson was born in Moscow , lives in Gothenburg , and works for Chalmers Technical University and the University of Gothenburg .')
John        B-PER
Johnson     I-PER
was         O
born        O
in          O
Moscow      B-LOC
,           O
lives       O
in          O
Gothenburg  B-LOC
,           O
and         O
works       O
for         O
Chalmers    B-ORG
Technical   I-ORG
University  I-ORG
and         O
the         O
University  B-ORG
of          I-ORG
Gothenburg  I-ORG
.           O

It is worth noting that the system is somewhat robust to words it hasn't observed before. In most cases, it will be able to pick up the pattern that the word following John should also be included in a multi-word person name, and that the place where someone was born is probably a location.

In [9]:
print_tags('John XYZXYZABC was born in XYZABC .')
John        B-PER
XYZXYZABC   I-PER
was         O
born        O
in          O
XYZABC      B-LOC
.           O

The following example is typically tagged incorrectly. We would expect Paris Hilton to be tagged as a person, but the system confuses this name with the French capital.

In [10]:
print_tags('Paris Hilton lives in New York .')
Paris       B-LOC
Hilton      O
lives       O
in          O
New         B-LOC
York        I-LOC
.           O

Here is another example of an ambiguous term that is most often handled correctly: New York is part of an organization name in the first case, and a location name in the second case.

In [11]:
print_tags('New York Stock Exchange is in New York .')
New         B-ORG
York        I-ORG
Stock       I-ORG
Exchange    I-ORG
is          O
in          O
New         B-LOC
York        I-LOC
.           O